package com.example.harmonet.harmtorch;

/**
 * Layer that reduces its input to the position of its largest element.
 *
 * <p>The class keeps no state of its own: {@link #getParam()} always reports
 * {@code 0} and {@link #init(float[])} does nothing.
 */
public class Argmax implements Layer {

    /**
     * Scans {@code in.tensor()} and reports the index of its greatest value.
     *
     * <p>A candidate only replaces the current winner when it is strictly
     * greater, so ties resolve to the lowest index. When the scanned array has
     * no elements, the reported index is {@code 0}.
     *
     * @param in tensor whose backing array is scanned
     * @return a tensor created with {@code new Tensor(1)} (presumably a
     *         single-element tensor; confirm against {@code Tensor}) whose
     *         element 0 holds the winning index
     * @throws Exception declared in the signature; this method contains no
     *         explicit {@code throw} of its own
     */
    @Override
    public Tensor forward(Tensor in) throws Exception {
        Tensor result = new Tensor(1);

        int best = 0;
        int pos = 0;
        while (pos < in.tensor().length) {
            // Strict comparison keeps the earliest index on ties.
            if (in.tensor()[pos] > in.tensor()[best]) {
                best = pos;
            }
            pos++;
        }

        result.tensor()[0] = best;
        return result;
    }

    /**
     * Always returns {@code 0}.
     *
     * <p>NOTE(review): presumably this is the number of parameters the layer
     * consumes; verify against the {@code Layer} interface.
     *
     * @return {@code 0}
     */
    @Override
    public int getParam() {
        return 0;
    }

    /**
     * Intentionally empty: the supplied array is ignored.
     *
     * @param param values offered to the layer; unused here
     */
    @Override
    public void init(float[] param) {
        // Nothing to initialise.
    }
}
